# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tfm.nlp.layers.EinsumDense`."""

from collections.abc import Mapping, Sequence
from typing import Any, Optional
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import einsum_utils


def einsum_layer_computation(
    layer_instance: tf.keras.layers.EinsumDense,
    input_args: Sequence[Any],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Registry function for `tf.keras.layers.EinsumDense`.

  For the technical details, see the documentation of
  `einsum_utils.compute_fast_einsum_gradient_norm()`.

  Args:
    layer_instance: A `tf.keras.layers.EinsumDense` instance.
    input_args: See `dense_layer_computation()` in `dense.py`.
    input_kwargs: See `dense_layer_computation()` in `dense.py`.
    tape: See `dense_layer_computation()` in `dense.py`.
    num_microbatches: See `dense_layer_computation()` in `dense.py`.

  Returns:
    See `dense_layer_computation()` in `dense.py`.

  Raises:
    ValueError: If `input_kwargs` is non-empty or if `input_args` does not
      contain exactly one element.
  """
  if input_kwargs:
    raise ValueError("EinsumDense layer calls should not receive kwargs.")
  del input_kwargs
  if len(input_args) != 1:
    raise ValueError("Only layer inputs of length 1 are permitted.")
  orig_activation = layer_instance.activation
  # Some activation functions may not apply a transform to the elements of the
  # output individually (which is needed for the fast clipping trick to work).
  # To avoid this case, we watch the variables that are only generated by the
  # linear transformation of the `EinsumDense` layer instance.
  layer_instance.activation = None
  try:
    base_vars = layer_instance(*input_args)
    tape.watch(base_vars)
  finally:
    # Restore the activation even if the layer call raises, so the caller's
    # layer instance is never left in a mutated (activation-stripped) state.
    layer_instance.activation = orig_activation
  # Apply the original activation outside the linear transform so that the
  # watched `base_vars` remain the pre-activation outputs.
  outputs = orig_activation(base_vars) if orig_activation else base_vars

  def sqr_norm_fn(grads):
    # Closure over the layer's einsum equation and bias axes: computes the
    # per-example squared gradient norms from the upstream grads w.r.t.
    # `base_vars` and the (single) layer input.
    return einsum_utils.compute_fast_einsum_squared_gradient_norm(
        layer_instance.equation,
        input_args[0],
        grads,
        layer_instance.bias_axes,
        num_microbatches,
    )

  return base_vars, outputs, sqr_norm_fn
